{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/decision-labs/geobase-ai.js/blob/automate-task-selection/task-classifier/explore-collab-notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JIkmn3eFeznp",
        "outputId": "cf854887-002e-42cc-b1b8-cd38c92dc4e9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.48.3)\n",
            "Collecting datasets\n",
            "  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.17.0)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.28.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.4)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.0)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.3)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (18.1.0)\n",
            "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
            "  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets) (2.2.2)\n",
            "Collecting xxhash (from datasets)\n",
            "  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
            "Collecting multiprocess<0.70.17 (from datasets)\n",
            "  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)\n",
            "Requirement already satisfied: fsspec<=2024.12.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets) (2024.10.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets) (3.11.13)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.5)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
            "  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
            "  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
            "  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
            "  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
            "  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
            "  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
            "  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
            "  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
            "  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
            "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
            "  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.13.1)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (2.4.6)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.3.2)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (25.1.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.5.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (6.1.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (0.3.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.18.3)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.4.1)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.3.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2025.1.31)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.1)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2025.1)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
            "Downloading datasets-3.3.2-py3-none-any.whl (485 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m485.4/485.4 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m37.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.5/143.5 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.8/194.8 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: xxhash, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, dill, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, nvidia-cusolver-cu12, datasets\n",
            "  Attempting uninstall: nvidia-nvjitlink-cu12\n",
            "    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
            "    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-curand-cu12\n",
            "    Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
            "    Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
            "      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
            "  Attempting uninstall: nvidia-cufft-cu12\n",
            "    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
            "    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
            "      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
            "  Attempting uninstall: nvidia-cuda-runtime-cu12\n",
            "    Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
            "    Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-cupti-cu12\n",
            "    Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cublas-cu12\n",
            "    Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
            "    Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
            "      Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
            "  Attempting uninstall: nvidia-cusparse-cu12\n",
            "    Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
            "    Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
            "      Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
            "  Attempting uninstall: nvidia-cudnn-cu12\n",
            "    Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
            "    Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
            "      Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
            "  Attempting uninstall: nvidia-cusolver-cu12\n",
            "    Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
            "    Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
            "      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
            "Successfully installed datasets-3.3.2 dill-0.3.8 multiprocess-0.70.16 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 xxhash-3.5.0\n"
          ]
        }
      ],
      "source": [
        "!pip install transformers datasets torch scikit-learn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YDOyCLNse6wG",
        "outputId": "ea9517d2-0296-42bf-c65f-d8a441c7da83"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nreimers/BERT-Tiny_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "source": [
        "from transformers import BertForSequenceClassification, BertTokenizer\n",
        "\n",
        "# Load tokenizer\n",
        "model_name = \"nreimers/BERT-Tiny_L-2_H-128_A-2\"\n",
        "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# Load model with three labels\n",
        "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=4)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_Zpsh-gfB3M"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# Load dataset from CSV\n",
        "dataset = load_dataset(\"geobase/geo-task-classifier\")\n",
        "\n",
        "# Tokenization function\n",
        "def tokenize_function(examples):\n",
        "    return tokenizer(examples[\"query\"], padding=\"max_length\", truncation=True)\n",
        "\n",
        "# Apply tokenization\n",
        "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n",
        "\n",
        "# # Remove text column\n",
        "tokenized_datasets = tokenized_datasets.remove_columns([\"query\"])\n",
        "tokenized_datasets.set_format(\"torch\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Cwn_7JjfoES"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, batch_size=16)\n",
        "test_dataloader = DataLoader(tokenized_datasets[\"test\"], batch_size=16)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yf6_N0sFuF6f",
        "outputId": "3e86c3c0-dd4d-4f07-b208-38a95e86c605"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Dataset({\n",
            "    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
            "    num_rows: 81\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "print(tokenized_datasets[\"train\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x-mXPY4WhY5q",
        "outputId": "ade5b04a-3f27-490a-a479-77ac34a61517"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "from transformers import AdamW\n",
        "\n",
        "# Define optimizer\n",
        "optimizer = AdamW(model.parameters(), lr=5e-5)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5SaucZHzhb8R",
        "outputId": "853db48d-36d2-4bcc-ea9b-c5f8ad4a5900"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1 - Average Loss: 1.0820\n",
            "Epoch 2 - Average Loss: 1.0996\n",
            "Epoch 3 - Average Loss: 1.0634\n",
            "Epoch 4 - Average Loss: 1.0841\n",
            "Epoch 5 - Average Loss: 1.0655\n",
            "Epoch 6 - Average Loss: 1.0416\n",
            "Epoch 7 - Average Loss: 1.0949\n",
            "Epoch 8 - Average Loss: 1.0884\n",
            "Epoch 9 - Average Loss: 1.0815\n",
            "Epoch 10 - Average Loss: 1.0882\n",
            "Epoch 11 - Average Loss: 1.1048\n",
            "Epoch 12 - Average Loss: 1.0900\n",
            "Epoch 13 - Average Loss: 1.1229\n",
            "Epoch 14 - Average Loss: 1.0679\n",
            "Epoch 15 - Average Loss: 1.0813\n",
            "Epoch 16 - Average Loss: 1.1099\n",
            "Epoch 17 - Average Loss: 1.0915\n",
            "Epoch 18 - Average Loss: 1.0805\n",
            "Epoch 19 - Average Loss: 1.0834\n",
            "Epoch 20 - Average Loss: 1.0857\n",
            "Epoch 21 - Average Loss: 1.1165\n",
            "Epoch 22 - Average Loss: 1.1017\n",
            "Epoch 23 - Average Loss: 1.0916\n",
            "Epoch 24 - Average Loss: 1.0981\n",
            "Epoch 25 - Average Loss: 1.0676\n",
            "Epoch 26 - Average Loss: 1.0628\n",
            "Epoch 27 - Average Loss: 1.0558\n",
            "Epoch 28 - Average Loss: 1.0942\n",
            "Epoch 29 - Average Loss: 1.0843\n",
            "Epoch 30 - Average Loss: 1.1024\n",
            "Epoch 31 - Average Loss: 1.0833\n",
            "Epoch 32 - Average Loss: 1.0726\n",
            "Epoch 33 - Average Loss: 1.0794\n",
            "Epoch 34 - Average Loss: 1.0842\n",
            "Epoch 35 - Average Loss: 1.0842\n",
            "Epoch 36 - Average Loss: 1.0597\n",
            "Epoch 37 - Average Loss: 1.0438\n",
            "Epoch 38 - Average Loss: 1.0743\n",
            "Epoch 39 - Average Loss: 1.0914\n",
            "Epoch 40 - Average Loss: 1.0798\n",
            "Epoch 41 - Average Loss: 1.0966\n",
            "Epoch 42 - Average Loss: 1.0714\n",
            "Epoch 43 - Average Loss: 1.0727\n",
            "Epoch 44 - Average Loss: 1.0825\n",
            "Epoch 45 - Average Loss: 1.0792\n",
            "Epoch 46 - Average Loss: 1.0821\n",
            "Epoch 47 - Average Loss: 1.0876\n",
            "Epoch 48 - Average Loss: 1.1118\n",
            "Epoch 49 - Average Loss: 1.1015\n",
            "Epoch 50 - Average Loss: 1.0811\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import get_scheduler\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.to(device)\n",
        "\n",
        "num_training_steps = len(train_dataloader) * 3  # 3 epochs\n",
        "lr_scheduler = get_scheduler(\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)\n",
        "\n",
        "# Training Loop\n",
        "epochs = 50\n",
        "for epoch in range(epochs):\n",
        "    model.train()\n",
        "    total_loss = 0\n",
        "    for batch in train_dataloader:\n",
        "        batch = {k: v.to(device) for k, v in batch.items()}\n",
        "        outputs = model(**batch)\n",
        "        loss = outputs.loss\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        lr_scheduler.step()\n",
        "        optimizer.zero_grad()\n",
        "        total_loss += loss.item()\n",
        "\n",
        "    avg_loss = total_loss / len(train_dataloader)\n",
        "    print(f\"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vUAU5vNTheFy",
        "outputId": "0deb5438-b29e-434c-9c32-216284c48ef6"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "('bert_tiny_finetuned/tokenizer_config.json',\n",
              " 'bert_tiny_finetuned/special_tokens_map.json',\n",
              " 'bert_tiny_finetuned/vocab.txt',\n",
              " 'bert_tiny_finetuned/added_tokens.json')"
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Save the trained model\n",
        "model.save_pretrained(\"bert_tiny_finetuned\")\n",
        "\n",
        "# Save the tokenizer as well\n",
        "tokenizer.save_pretrained(\"bert_tiny_finetuned\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "hn_wvhDRvCtc",
        "outputId": "635994e8-5560-49bf-9309-a8088dbe3fbb"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "<ipython-input-15-a48d7431edb5>:23: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='401' max='550' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [401/550 05:08 < 01:55, 1.29 it/s, Epoch 36.36/50]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.328542</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.314359</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.297669</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.268253</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.247607</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.225565</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.196141</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.185258</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.162513</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.128378</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.103570</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.081249</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.057329</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.032270</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.009215</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.991285</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.972464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.949159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.926883</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.906142</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.886904</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.871937</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.852121</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.835691</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.816944</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.802648</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.793640</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.778252</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.767693</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.757722</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.751569</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.743060</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.735530</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.728663</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.722049</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.715355</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='550' max='550' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [550/550 07:06, Epoch 50/50]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.328542</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.314359</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.297669</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.268253</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.247607</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.225565</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.196141</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.185258</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.162513</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.128378</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.103570</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.081249</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.057329</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.032270</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>No log</td>\n",
              "      <td>1.009215</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.991285</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.972464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.949159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.926883</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.906142</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.886904</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.871937</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.852121</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.835691</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.816944</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.802648</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.793640</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.778252</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.767693</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.757722</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.751569</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.743060</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.735530</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.728663</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.722049</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.715355</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.711861</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.707336</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.702474</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.697628</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.693139</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.688997</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.685936</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.683043</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>No log</td>\n",
              "      <td>0.680715</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>0.754400</td>\n",
              "      <td>0.678543</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>0.754400</td>\n",
              "      <td>0.677370</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>0.754400</td>\n",
              "      <td>0.676950</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>0.754400</td>\n",
              "      <td>0.676782</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>0.754400</td>\n",
              "      <td>0.676685</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=550, training_loss=0.7252086015181108, metrics={'train_runtime': 428.1857, 'train_samples_per_second': 9.459, 'train_steps_per_second': 1.284, 'total_flos': 5148682444800.0, 'train_loss': 0.7252086015181108, 'epoch': 50.0})"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "\n",
        "\n",
        "from transformers import Trainer, TrainingArguments\n",
        "\n",
        "# Assuming train_dataset and eval_dataset are already defined\n",
        "# and properly formatted (with 'input_ids', 'attention_mask', 'labels')\n",
        "\n",
        "# Assuming tokenizer is already defined from previous code\n",
        "\n",
        "# Define training arguments\n",
        "training_args = TrainingArguments(\n",
        "    output_dir=\"./bert_tiny_trained\",\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    per_device_train_batch_size=8,\n",
        "    per_device_eval_batch_size=8,\n",
        "    num_train_epochs=50,\n",
        "    logging_dir=\"./logs\",\n",
        "    save_total_limit=2,\n",
        "    load_best_model_at_end=True,\n",
        "    save_strategy=\"epoch\",\n",
        "    # Add more arguments as needed\n",
        ")\n",
        "\n",
        "# Define trainer\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=tokenized_datasets[\"train\"],\n",
        "    eval_dataset=tokenized_datasets[\"validation\"], # Use the test dataset for evaluation\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zap2Kqzx8u0E",
        "outputId": "1445025c-854a-436b-ef18-ad2623848842"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting evaluate\n",
            "  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)\n",
            "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.3.2)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from evaluate) (1.26.4)\n",
            "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.32.3)\n",
            "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.11/dist-packages (from evaluate) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.5.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.70.16)\n",
            "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2024.10.0)\n",
            "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.28.1)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from evaluate) (24.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (3.17.0)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (18.1.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (3.11.13)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (6.0.2)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.12.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.4.1)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2.3.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2025.1.31)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.1)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2025.1)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.4.6)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.2)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (25.1.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.5.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.1.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (0.3.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.18.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.17.0)\n",
            "Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: evaluate\n",
            "Successfully installed evaluate-0.4.3\n"
          ]
        }
      ],
      "source": [
        "!pip install evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 381,
          "referenced_widgets": [
            "a9830b03d98f49bdacf7e37c303c8290",
            "8acf79f3c6be477caf06fc832797467b",
            "0dd4cd9d03c64faf94162e52e40bcb3d",
            "6fb6f2643bfe438b862e54b949c87495",
            "0da9fa8b44a947e8a164ab65a691ce62",
            "ba50f0a946934d4ab18ce37c91f2bdf5",
            "a89ac7b354a04a10b7bd76a2bf0d3154",
            "5b9f70336c3d4a298803cfdade18adf6",
            "b62115d7bd6944719583c2b214967c4f",
            "123f574c83b14d268c1296dd395a8165",
            "6df772bbfea74b85ae2ed2abc00c6bf7"
          ]
        },
        "id": "22PRPzZr7NTR",
        "outputId": "256bbb3f-e0ad-44cb-f8f5-bd9390cf9d10"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a9830b03d98f49bdacf7e37c303c8290",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading builder script:   0%|          | 0.00/6.79k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'f1': 0.7425641025641025}\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.67      1.00      0.80         4\n",
            "           1       1.00      0.67      0.80         3\n",
            "           2       0.00      0.00      0.00         2\n",
            "           3       0.86      1.00      0.92         6\n",
            "\n",
            "    accuracy                           0.80        15\n",
            "   macro avg       0.63      0.67      0.63        15\n",
            "weighted avg       0.72      0.80      0.74        15\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
            "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
          ]
        }
      ],
      "source": [
        "# prompt: write a code to test the model using trainer and test dataset also print  metrices  there is not any accuracy\n",
        "import torch\n",
        "from transformers import BertForSequenceClassification, BertTokenizer\n",
        "from torch.utils.data import DataLoader\n",
        "import numpy as np\n",
        "# from datasets import load_metric  # This line caused the issue\n",
        "from evaluate import load # Import load from evaluate\n",
        "\n",
        "# Load the trained model and tokenizer\n",
        "model = BertForSequenceClassification.from_pretrained(\"bert_tiny_trained\")\n",
        "tokenizer = BertTokenizer.from_pretrained(\"bert_tiny_trained\")\n",
        "\n",
        "# Move the model to the appropriate device\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.to(device)\n",
        "\n",
        "\n",
        "# Prediction loop\n",
        "model.eval()\n",
        "predictions = []\n",
        "labels = []\n",
        "\n",
        "for batch in test_dataloader:\n",
        "    batch = {k: v.to(device) for k, v in batch.items()}\n",
        "    with torch.no_grad():\n",
        "        outputs = model(**batch)\n",
        "    logits = outputs.logits\n",
        "    predicted_class = np.argmax(logits.cpu().numpy(), axis=1)  # Get the predicted class\n",
        "    predictions.extend(predicted_class)\n",
        "    labels.extend(batch[\"labels\"].cpu().numpy())\n",
        "\n",
        "# Calculate metrics (example: precision, recall, F1-score)\n",
        "# Changed to use 'evaluate' library\n",
        "metric = load(\"f1\") # or any other relevant metric\n",
        "metric.add_batch(predictions=predictions, references=labels)\n",
        "metrics = metric.compute(average=\"weighted\") #weighted average for multiple labels\n",
        "\n",
        "# Print the metrics\n",
        "print(metrics)\n",
        "\n",
        "\n",
        "# Example for calculating other metrics using scikit-learn:\n",
        "from sklearn.metrics import classification_report\n",
        "print(classification_report(labels, predictions))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SOUw5OTmwWZC"
      },
      "outputs": [],
      "source": [
        "trainer.save_model()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 156,
          "referenced_widgets": [
            "ccbf637fa33c431e9fb7d8ba62764b98",
            "7917bba39d1942a39b85bca2164a23d3",
            "1e855602d6db4c31949f6b027712e01f",
            "285cd8c2ca574d47b88632ace3733412",
            "0c3d0e66e13843e8bba327b9dc6542f6",
            "2bfc066bb4164f95b81f2dc83a3cdf8d",
            "8997a806fc374cbbad612365c56f6bb0",
            "afcb5f0fdbf3470fbde672f13e16d6f4",
            "9e888173851f45188cc97f3783ab1996",
            "de028d58c5004675b0c9e57f1c66e46b",
            "175edb8ed27f406eb4d67f6f92595f0b"
          ]
        },
        "id": "-Z3VI6dDwnlJ",
        "outputId": "1442714d-a9ed-489b-fb4f-3aac4e2b1f71"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ccbf637fa33c431e9fb7d8ba62764b98",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/17.6M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "CommitInfo(commit_url='https://huggingface.co/mhassanch/bert_tiny_trained/commit/c3b2084fd44ac65bb2f435c33143f2c5b4a71949', commit_message='End of training', commit_description='', oid='c3b2084fd44ac65bb2f435c33143f2c5b4a71949', pr_url=None, repo_url=RepoUrl('https://huggingface.co/mhassanch/bert_tiny_trained', endpoint='https://huggingface.co', repo_type='model', repo_id='mhassanch/bert_tiny_trained'), pr_revision=None, pr_num=None)"
            ]
          },
          "execution_count": 25,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.push_to_hub()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ta8CSCwTw4Fv",
        "outputId": "bcf4d3b8-5b58-47d6-835b-1330f14185f5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting optimum\n",
            "  Downloading optimum-1.24.0-py3-none-any.whl.metadata (21 kB)\n",
            "Requirement already satisfied: transformers>=4.29 in /usr/local/lib/python3.11/dist-packages (from optimum) (4.48.3)\n",
            "Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.11/dist-packages (from optimum) (2.5.1+cu124)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from optimum) (24.2)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from optimum) (1.26.4)\n",
            "Requirement already satisfied: huggingface-hub>=0.8.0 in /usr/local/lib/python3.11/dist-packages (from optimum) (0.28.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (3.17.0)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (2024.10.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (6.0.2)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (2.32.3)\n",
            "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (4.67.1)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.8.0->optimum) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (3.1.5)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.127)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (9.1.0.70)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.5.8)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (11.2.1.3)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (10.3.5.147)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (11.6.1.9)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.3.1.170)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.127)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (12.4.127)\n",
            "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (3.1.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.11->optimum) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.11->optimum) (1.3.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.29->optimum) (2024.11.6)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.29->optimum) (0.21.0)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers>=4.29->optimum) (0.5.3)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.11->optimum) (3.0.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.8.0->optimum) (3.4.1)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.8.0->optimum) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.8.0->optimum) (2.3.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.8.0->optimum) (2025.1.31)\n",
            "Downloading optimum-1.24.0-py3-none-any.whl (433 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m433.6/433.6 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: optimum\n",
            "Successfully installed optimum-1.24.0\n"
          ]
        }
      ],
      "source": [
        "!pip install optimum"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rckQJQqnMiEt",
        "outputId": "45cf317d-9da7-4c92-ab05-8bd8cb5d1c85"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting onnxruntime\n",
            "  Downloading onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)\n",
            "Collecting coloredlogs (from onnxruntime)\n",
            "  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n",
            "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (25.2.10)\n",
            "Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (1.26.4)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (24.2)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (4.25.6)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from onnxruntime) (1.13.1)\n",
            "Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n",
            "  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->onnxruntime) (1.3.0)\n",
            "Downloading onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (13.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m88.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: humanfriendly, coloredlogs, onnxruntime\n",
            "Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.20.1\n"
          ]
        }
      ],
      "source": [
        "!pip install onnxruntime"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Sinap317wvis",
        "outputId": "354eeb0a-b363-4c70-decf-ad72a38eaefe"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2025-03-05 12:06:12.380532: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
            "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
            "E0000 00:00:1741176372.405574   23097 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
            "E0000 00:00:1741176372.412941   23097 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
          ]
        }
      ],
      "source": [
        "!optimum-cli export onnx --model mhassanch/bert_tiny_trained bert_tiny_trained_onnx/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9GbTbyY1xuZz",
        "outputId": "efc01bdd-e2c5-485a-9768-8fa300fdad0b"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:root:Please consider to run pre-processing before quantization. Refer to example: https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md \n"
          ]
        }
      ],
      "source": [
        "import onnx\n",
        "from onnxruntime.quantization import quantize_dynamic, QuantType\n",
        "\n",
        "model_fp32 = '/content/bert_tiny_trained_onnx/model.onnx'\n",
        "model_quant = '/content/bert_tiny_trained_onnx/model.quant.onnx'\n",
        "quantized_model = quantize_dynamic(model_fp32, model_quant)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 408
        },
        "id": "ltTtezHwM4l8",
        "outputId": "d81fe181-3a43-4afe-bf3a-bf9cdc2a8302"
      },
      "outputs": [
        {
          "ename": "ValueError",
          "evalue": "Required inputs (['token_type_ids']) are missing from input feed (['input_ids', 'attention_mask']).",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-48-816060ca5465>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0;31m# Run inference\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexample_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;31m# Process the outputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, output_names, input_feed, run_options)\u001b[0m\n\u001b[1;32m    260\u001b[0m             \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0moutput_name\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0minput_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    261\u001b[0m         \"\"\"\n\u001b[0;32m--> 262\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_feed\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    263\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    264\u001b[0m             \u001b[0moutput_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outputs_meta\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py\u001b[0m in \u001b[0;36m_validate_input\u001b[0;34m(self, feed_input_names)\u001b[0m\n\u001b[1;32m    242\u001b[0m                 \u001b[0mmissing_input_names\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    243\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mmissing_input_names\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 244\u001b[0;31m             raise ValueError(\n\u001b[0m\u001b[1;32m    245\u001b[0m                 \u001b[0;34mf\"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    246\u001b[0m             )\n",
            "\u001b[0;31mValueError\u001b[0m: Required inputs (['token_type_ids']) are missing from input feed (['input_ids', 'attention_mask'])."
          ]
        }
      ],
      "source": [
        "# prompt: write a code to test the onnx quantized model\n",
        "\n",
        "import onnxruntime as ort\n",
        "import numpy as np\n",
        "\n",
        "# Load the quantized ONNX model\n",
        "sess = ort.InferenceSession(\"/content/bert_tiny_trained_onnx/model.quant.onnx\")\n",
        "\n",
        "# Get input names\n",
        "input_names = [input.name for input in sess.get_inputs()]\n",
        "\n",
        "# Example input data (replace with your actual data)\n",
        "# Make sure the input data matches the expected shape and type\n",
        "example_input = {\n",
        "    'input_ids': np.random.randint(0, 100, size=(1, 128), dtype=np.int64),  # Example input_ids\n",
        "    'attention_mask': np.ones((1, 128), dtype=np.int64)  # Example attention_mask\n",
        "}\n",
        "\n",
        "# Run inference\n",
        "outputs = sess.run(None, example_input)\n",
        "\n",
        "# Process the outputs\n",
        "# The output format might be different depending on your model\n",
        "predicted_class = np.argmax(outputs[0], axis=1)\n",
        "\n",
        "print(f\"Predicted Class: {predicted_class}\")\n",
        "\n",
        "#Further evaluation can be performed by comparing the predictions\n",
        "#with the ground truth labels of a test dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xNCXZm5Nu1LT"
      },
      "outputs": [],
      "source": [
        "!pip install onnx onnxruntime"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dH_p9yYs-QhO"
      },
      "outputs": [],
      "source": [
        "test_queries = [\n",
        "    'Detect cars in this area',\n",
        "    'Mask the green field in this area',\n",
        "    'Shade the rooftop in this residental areas',\n",
        "    'Highlight the roadways in the map',\n",
        "    'Identify the airplanes on the airstrip'\n",
        "    ]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "riqfWl_dyGDi",
        "outputId": "7d0ef22f-fd79-4aa8-dfee-bb53d6a0bfe8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Query: Detect cars in this area\n",
            "Prediction: object-detection:geobase/WALDO30_yolov8m_640x640\n",
            "--------------------\n",
            "Query: Mask the green field in this area\n",
            "Prediction: mask-generation:Xenova/slimsam-77-uniform\n",
            "--------------------\n",
            "Query: Shade the rooftop in this residental areas\n",
            "Prediction: mask-generation:Xenova/slimsam-77-uniform\n",
            "--------------------\n",
            "Query: Highlight the roadways in the map\n",
            "Prediction: mask-generation:Xenova/slimsam-77-uniform\n",
            "--------------------\n",
            "Query: Identify the airplanes on the airstrip\n",
            "Prediction: object-detection:geobase/WALDO30_yolov8m_640x640\n",
            "--------------------\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline\n",
        "\n",
        "label_mapping = {\n",
        "    0: \"object-detection:geobase/WALDO30_yolov8m_640x640\",\n",
        "    1: \"zero-shot-object-detection:onnx-community/grounding-dino-tiny-ONNX\",\n",
        "    2: \"zero-shot-object-detection:Xenova/owlvit-base-patch32\",\n",
        "    3: \"mask-generation:Xenova/slimsam-77-uniform\"    # For queries that start with \"Segment\"\n",
        "}\n",
        "\n",
        "# Load the model with explicit label mapping\n",
        "model = AutoModelForSequenceClassification.from_pretrained(\n",
        "    \"bert_tiny_trained\",  # Replace with your saved model path\n",
        "    num_labels=4,\n",
        "    id2label=label_mapping,\n",
        "    label2id={v: k for k, v in label_mapping.items()}\n",
        ")\n",
        "\n",
        "# Load tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"mhassanch/bert_tiny_trained\")\n",
        "\n",
        "for query in test_queries:\n",
        "    inputs = tokenizer(query, padding=True, truncation=True, return_tensors=\"pt\")\n",
        "    output = model(**inputs)\n",
        "\n",
        "    # Get predicted class\n",
        "    probabilities = torch.softmax(output.logits, dim=1)  # Apply softmax\n",
        "    predicted_class_index = torch.argmax(probabilities, dim=1).item()  # Get index of highest probability\n",
        "    predicted_class_label = model.config.id2label[predicted_class_index]  # Get label from index\n",
        "\n",
        "    print(f\"Query: {query}\")\n",
        "    print(f\"Prediction: {predicted_class_label}\")\n",
        "    print(\"-\" * 20)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "097xEzLQze5B"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0c3d0e66e13843e8bba327b9dc6542f6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0da9fa8b44a947e8a164ab65a691ce62": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0dd4cd9d03c64faf94162e52e40bcb3d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5b9f70336c3d4a298803cfdade18adf6",
            "max": 6785,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_b62115d7bd6944719583c2b214967c4f",
            "value": 6785
          }
        },
        "123f574c83b14d268c1296dd395a8165": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "175edb8ed27f406eb4d67f6f92595f0b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1e855602d6db4c31949f6b027712e01f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_afcb5f0fdbf3470fbde672f13e16d6f4",
            "max": 17550344,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_9e888173851f45188cc97f3783ab1996",
            "value": 17550344
          }
        },
        "285cd8c2ca574d47b88632ace3733412": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_de028d58c5004675b0c9e57f1c66e46b",
            "placeholder": "​",
            "style": "IPY_MODEL_175edb8ed27f406eb4d67f6f92595f0b",
            "value": " 17.6M/17.6M [00:00&lt;00:00, 29.7MB/s]"
          }
        },
        "2bfc066bb4164f95b81f2dc83a3cdf8d": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5b9f70336c3d4a298803cfdade18adf6": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6df772bbfea74b85ae2ed2abc00c6bf7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6fb6f2643bfe438b862e54b949c87495": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_123f574c83b14d268c1296dd395a8165",
            "placeholder": "​",
            "style": "IPY_MODEL_6df772bbfea74b85ae2ed2abc00c6bf7",
            "value": " 6.79k/6.79k [00:00&lt;00:00, 117kB/s]"
          }
        },
        "7917bba39d1942a39b85bca2164a23d3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2bfc066bb4164f95b81f2dc83a3cdf8d",
            "placeholder": "​",
            "style": "IPY_MODEL_8997a806fc374cbbad612365c56f6bb0",
            "value": "model.safetensors: 100%"
          }
        },
        "8997a806fc374cbbad612365c56f6bb0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8acf79f3c6be477caf06fc832797467b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ba50f0a946934d4ab18ce37c91f2bdf5",
            "placeholder": "​",
            "style": "IPY_MODEL_a89ac7b354a04a10b7bd76a2bf0d3154",
            "value": "Downloading builder script: 100%"
          }
        },
        "9e888173851f45188cc97f3783ab1996": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "a89ac7b354a04a10b7bd76a2bf0d3154": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a9830b03d98f49bdacf7e37c303c8290": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_8acf79f3c6be477caf06fc832797467b",
              "IPY_MODEL_0dd4cd9d03c64faf94162e52e40bcb3d",
              "IPY_MODEL_6fb6f2643bfe438b862e54b949c87495"
            ],
            "layout": "IPY_MODEL_0da9fa8b44a947e8a164ab65a691ce62"
          }
        },
        "afcb5f0fdbf3470fbde672f13e16d6f4": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b62115d7bd6944719583c2b214967c4f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ba50f0a946934d4ab18ce37c91f2bdf5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ccbf637fa33c431e9fb7d8ba62764b98": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7917bba39d1942a39b85bca2164a23d3",
              "IPY_MODEL_1e855602d6db4c31949f6b027712e01f",
              "IPY_MODEL_285cd8c2ca574d47b88632ace3733412"
            ],
            "layout": "IPY_MODEL_0c3d0e66e13843e8bba327b9dc6542f6"
          }
        },
        "de028d58c5004675b0c9e57f1c66e46b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
